{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6df066b3",
   "metadata": {},
   "source": [
    "## 配置及导入必要的包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f175e42",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:10:50.947211Z",
     "start_time": "2023-04-20T16:10:49.173659Z"
    }
   },
   "outputs": [],
   "source": [
    "import uuid\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a74461c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:10:51.769575Z",
     "start_time": "2023-04-20T16:10:51.761955Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0421'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "today = datetime.now().strftime('%m%d')\n",
    "today"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a415c9de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:10:53.059523Z",
     "start_time": "2023-04-20T16:10:53.048262Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1f4e1b1c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:07.811505Z",
     "start_time": "2023-04-20T16:11:07.803082Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch_dtype = torch.float16 if 'cuda' in device else torch.float32\n",
    "torch_dtype = torch.float32\n",
    "torch_dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f94d3790",
   "metadata": {},
   "source": [
    "## model & config & forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "16d59f47",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:12:33.664632Z",
     "start_time": "2023-04-20T16:12:33.656269Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.877717391304348"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(11064-8)/(5896-8)"
   ]
  },
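  {
   "cell_type": "markdown",
   "id": "f6a9ea43",
   "metadata": {},
   "source": [
    "Memory readings per dtype (MiB); the `-8` is the offset subtracted from each reading before comparing:\n",
    "\n",
    "- torch.float16\n",
    "    - load: 3588MiB-8\n",
    "    - forward: 5896MiB-8\n",
    "- torch.float32\n",
    "    - load: 6268MiB-8\n",
    "    - forward: 11064MiB-8"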
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5304aaa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:22.363491Z",
     "start_time": "2023-04-20T16:11:18.476113Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
      "/home/whaow/anaconda3/envs/visgpt/lib/python3.8/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', \n",
    "                                               torch_dtype=torch_dtype).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9dffb9a3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:08:17.728568Z",
     "start_time": "2023-04-20T16:08:17.722717Z"
    }
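   },
   "outputs": [],
   "source": [
    "# Where the default 512x512 output size comes from (the quoted lines appear to be from the pipeline source):\n",
    "# vae_scale_factor = 2 ** (4 - 1) = 8, because the VAE config lists 4 block_out_channels\n",
    "# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n",
    "# default side length = unet sample_size * vae_scale_factor = 64 * 8 = 512\n",
    "# height = height or self.unet.config.sample_size * self.vae_scale_factor\n",
    "# width = width or self.unet.config.sample_size * self.vae_scale_factor"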
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "23818318",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:07:20.547562Z",
     "start_time": "2023-04-20T16:07:20.541012Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FrozenDict([('in_channels', 3),\n",
       "            ('out_channels', 3),\n",
       "            ('down_block_types',\n",
       "             ['DownEncoderBlock2D',\n",
       "              'DownEncoderBlock2D',\n",
       "              'DownEncoderBlock2D',\n",
       "              'DownEncoderBlock2D']),\n",
       "            ('up_block_types',\n",
       "             ['UpDecoderBlock2D',\n",
       "              'UpDecoderBlock2D',\n",
       "              'UpDecoderBlock2D',\n",
       "              'UpDecoderBlock2D']),\n",
       "            ('block_out_channels', [128, 256, 512, 512]),\n",
       "            ('layers_per_block', 2),\n",
       "            ('act_fn', 'silu'),\n",
       "            ('latent_channels', 4),\n",
       "            ('norm_num_groups', 32),\n",
       "            ('sample_size', 512),\n",
       "            ('scaling_factor', 0.18215),\n",
       "            ('_class_name', 'AutoencoderKL'),\n",
       "            ('_diffusers_version', '0.6.0'),\n",
       "            ('_name_or_path',\n",
       "             '/home/whaow/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/vae')])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.vae.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2c3f7c2f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:07:23.268091Z",
     "start_time": "2023-04-20T16:07:23.258835Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FrozenDict([('sample_size', 64),\n",
       "            ('in_channels', 4),\n",
       "            ('out_channels', 4),\n",
       "            ('center_input_sample', False),\n",
       "            ('flip_sin_to_cos', True),\n",
       "            ('freq_shift', 0),\n",
       "            ('down_block_types',\n",
       "             ['CrossAttnDownBlock2D',\n",
       "              'CrossAttnDownBlock2D',\n",
       "              'CrossAttnDownBlock2D',\n",
       "              'DownBlock2D']),\n",
       "            ('mid_block_type', 'UNetMidBlock2DCrossAttn'),\n",
       "            ('up_block_types',\n",
       "             ['UpBlock2D',\n",
       "              'CrossAttnUpBlock2D',\n",
       "              'CrossAttnUpBlock2D',\n",
       "              'CrossAttnUpBlock2D']),\n",
       "            ('only_cross_attention', False),\n",
       "            ('block_out_channels', [320, 640, 1280, 1280]),\n",
       "            ('layers_per_block', 2),\n",
       "            ('downsample_padding', 1),\n",
       "            ('mid_block_scale_factor', 1),\n",
       "            ('act_fn', 'silu'),\n",
       "            ('norm_num_groups', 32),\n",
       "            ('norm_eps', 1e-05),\n",
       "            ('cross_attention_dim', 768),\n",
       "            ('encoder_hid_dim', None),\n",
       "            ('attention_head_dim', 8),\n",
       "            ('dual_cross_attention', False),\n",
       "            ('use_linear_projection', False),\n",
       "            ('class_embed_type', None),\n",
       "            ('num_class_embeds', None),\n",
       "            ('upcast_attention', False),\n",
       "            ('resnet_time_scale_shift', 'default'),\n",
       "            ('resnet_skip_time_act', False),\n",
       "            ('resnet_out_scale_factor', 1.0),\n",
       "            ('time_embedding_type', 'positional'),\n",
       "            ('time_embedding_act_fn', None),\n",
       "            ('timestep_post_act', None),\n",
       "            ('time_cond_proj_dim', None),\n",
       "            ('conv_in_kernel', 3),\n",
       "            ('conv_out_kernel', 3),\n",
       "            ('projection_class_embeddings_input_dim', None),\n",
       "            ('class_embeddings_concat', False),\n",
       "            ('mid_block_only_cross_attention', None),\n",
       "            ('cross_attention_norm', None),\n",
       "            ('_class_name', 'UNet2DConditionModel'),\n",
       "            ('_diffusers_version', '0.6.0'),\n",
       "            ('_name_or_path',\n",
       "             '/home/whaow/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/unet')])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.unet.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aa026555",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:34.502396Z",
     "start_time": "2023-04-20T16:11:34.499033Z"
    }
   },
   "outputs": [],
   "source": [
    "a_prompt = 'best quality, extremely detailed'\n",
    "n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \\\n",
    "                        'fewer digits, cropped, worst quality, low quality'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "41f1145b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:35.367667Z",
     "start_time": "2023-04-20T16:11:35.360012Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'a photo of an astronaut riding a horse on mars, best quality, extremely detailed'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# prompt = 'a cat in the living room'\n",
    "prompt = 'a photo of an astronaut riding a horse on mars'\n",
    "prompt = prompt + ', ' + a_prompt\n",
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1b37506b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:43.442295Z",
     "start_time": "2023-04-20T16:11:36.560959Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84af4eaeb55141158e25d88230b246d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img = pipe(prompt, negative_prompt=n_prompt).images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bb6c1156",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-20T16:11:54.758340Z",
     "start_time": "2023-04-20T16:11:54.671868Z"
    }
   },
   "outputs": [],
   "source": [
    "img.save(f'../results/image-{today}-{str(uuid.uuid4())[:8]}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bbbee44",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "221px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
